# Copyright (c) Microsoft. All rights reserved.

"""Tests for AGUIChatClient."""

import base64
import json
from collections.abc import AsyncGenerator, AsyncIterable, MutableSequence
from typing import Any

from agent_framework import (
    ChatMessage,
    ChatOptions,
    ChatResponseUpdate,
    DataContent,
    FunctionCallContent,
    Role,
    TextContent,
    ai_function,
)
from agent_framework._types import ChatResponse
from pytest import MonkeyPatch

from agent_framework_ag_ui._client import AGUIChatClient, ServerFunctionCallContent
from agent_framework_ag_ui._http_service import AGUIHttpService


class TestableAGUIChatClient(AGUIChatClient):
    """AGUIChatClient subclass that re-exports protected internals under public names for tests."""

    @property
    def http_service(self) -> AGUIHttpService:
        """Return the client's underlying HTTP service so tests can monkeypatch ``post_run``."""
        service = self._http_service
        return service

    def extract_state_from_messages(
        self, messages: list[ChatMessage]
    ) -> tuple[list[ChatMessage], dict[str, Any] | None]:
        """Forward to ``_extract_state_from_messages`` and hand back its (messages, state) pair."""
        extracted = self._extract_state_from_messages(messages)
        return extracted

    def convert_messages_to_agui_format(self, messages: list[ChatMessage]) -> list[dict[str, Any]]:
        """Forward to ``_convert_messages_to_agui_format`` and return the converted dicts."""
        converted = self._convert_messages_to_agui_format(messages)
        return converted

    def get_thread_id(self, chat_options: ChatOptions) -> str:
        """Forward to ``_get_thread_id`` for the given options."""
        thread_id = self._get_thread_id(chat_options)
        return thread_id

    async def inner_get_streaming_response(
        self, *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Re-yield every update produced by ``_inner_get_streaming_response``."""
        stream = self._inner_get_streaming_response(messages=messages, chat_options=chat_options)
        async for chunk in stream:
            yield chunk

    async def inner_get_response(
        self, *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions
    ) -> ChatResponse:
        """Await ``_inner_get_response`` and return the resulting ChatResponse."""
        result = await self._inner_get_response(messages=messages, chat_options=chat_options)
        return result


def _json_data_message(json_text: str) -> ChatMessage:
    """Build a user message whose only content is ``json_text`` as a base64 JSON data URI.

    This is the message shape the state-extraction tests below feed to the client.
    ``json_text`` is encoded verbatim and need not be valid JSON, so the
    invalid-payload path can be exercised with the same helper.
    """
    encoded = base64.b64encode(json_text.encode("utf-8")).decode("utf-8")
    return ChatMessage(
        role="user",
        contents=[DataContent(uri=f"data:application/json;base64,{encoded}")],
    )


class TestAGUIChatClient:
    """Test suite for AGUIChatClient.

    Tests that exercise the request path avoid network I/O by monkeypatching
    ``post_run`` on the client's http service with an async generator that
    replays canned AG-UI events.
    """

    async def test_client_initialization(self) -> None:
        """The client creates an http service pointed at the configured endpoint."""
        client = TestableAGUIChatClient(endpoint="http://localhost:8888/")

        assert client.http_service is not None
        assert client.http_service.endpoint.startswith("http://localhost:8888")

    async def test_client_context_manager(self) -> None:
        """The client can be used as an async context manager."""
        async with TestableAGUIChatClient(endpoint="http://localhost:8888/") as client:
            assert client is not None

    async def test_extract_state_from_messages_no_state(self) -> None:
        """Without a state message, messages pass through unchanged and state is None."""
        client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
        messages = [
            ChatMessage(role="user", text="Hello"),
            ChatMessage(role="assistant", text="Hi there"),
        ]

        result_messages, state = client.extract_state_from_messages(messages)

        assert result_messages == messages
        assert state is None

    async def test_extract_state_from_messages_with_state(self) -> None:
        """A trailing JSON data-URI message is removed and decoded into the state dict."""
        client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
        state_data = {"key": "value", "count": 42}
        messages = [
            ChatMessage(role="user", text="Hello"),
            _json_data_message(json.dumps(state_data)),
        ]

        result_messages, state = client.extract_state_from_messages(messages)

        # Only the real conversation message should remain after extraction.
        assert len(result_messages) == 1
        assert result_messages[0].text == "Hello"
        assert state == state_data

    async def test_extract_state_invalid_json(self) -> None:
        """An undecodable state payload is ignored: messages are untouched and state is None."""
        client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
        messages = [_json_data_message("not valid json")]

        result_messages, state = client.extract_state_from_messages(messages)

        assert result_messages == messages
        assert state is None

    async def test_convert_messages_to_agui_format(self) -> None:
        """Messages are converted to AG-UI dicts carrying role, content and (when set) id."""
        client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
        messages = [
            ChatMessage(role=Role.USER, text="What is the weather?"),
            ChatMessage(role=Role.ASSISTANT, text="Let me check.", message_id="msg_123"),
        ]

        agui_messages = client.convert_messages_to_agui_format(messages)

        assert len(agui_messages) == 2
        assert agui_messages[0]["role"] == "user"
        assert agui_messages[0]["content"] == "What is the weather?"
        assert agui_messages[1]["role"] == "assistant"
        assert agui_messages[1]["content"] == "Let me check."
        assert agui_messages[1]["id"] == "msg_123"

    async def test_get_thread_id_from_metadata(self) -> None:
        """A thread_id supplied in the options metadata is reused as-is."""
        client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
        chat_options = ChatOptions(metadata={"thread_id": "existing_thread_123"})

        thread_id = client.get_thread_id(chat_options)

        assert thread_id == "existing_thread_123"

    async def test_get_thread_id_generation(self) -> None:
        """Without a thread_id in metadata, a new ``thread_``-prefixed id is generated."""
        client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
        chat_options = ChatOptions()

        thread_id = client.get_thread_id(chat_options)

        prefix = "thread_"
        assert thread_id.startswith(prefix)
        # Something must follow the prefix.
        assert len(thread_id) > len(prefix)

    async def test_get_streaming_response(self, monkeypatch: MonkeyPatch) -> None:
        """Each AG-UI event yields one update; text deltas surface as TextContent."""
        mock_events = [
            {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
            {"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": "Hello"},
            {"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": " world"},
            {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
        ]

        async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
            for event in mock_events:
                yield event

        client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
        monkeypatch.setattr(client.http_service, "post_run", mock_post_run)

        messages = [ChatMessage(role="user", text="Test message")]
        chat_options = ChatOptions()

        updates: list[ChatResponseUpdate] = []
        async for update in client.inner_get_streaming_response(messages=messages, chat_options=chat_options):
            updates.append(update)

        assert len(updates) == 4
        assert updates[0].additional_properties is not None
        assert updates[0].additional_properties["thread_id"] == "thread_1"

        first_content = updates[1].contents[0]
        second_content = updates[2].contents[0]
        assert isinstance(first_content, TextContent)
        assert isinstance(second_content, TextContent)
        assert first_content.text == "Hello"
        assert second_content.text == " world"

    async def test_get_response_non_streaming(self, monkeypatch: MonkeyPatch) -> None:
        """The non-streaming path aggregates streamed text into a ChatResponse."""
        mock_events = [
            {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
            {"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": "Complete response"},
            {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
        ]

        async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
            for event in mock_events:
                yield event

        client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
        monkeypatch.setattr(client.http_service, "post_run", mock_post_run)

        messages = [ChatMessage(role="user", text="Test message")]
        chat_options = ChatOptions()

        response = await client.inner_get_response(messages=messages, chat_options=chat_options)

        assert response is not None
        assert len(response.messages) > 0
        assert "Complete response" in response.text

    async def test_tool_handling(self, monkeypatch: MonkeyPatch) -> None:
        """Test that client tool metadata is sent to server.

        Client tool metadata (name, description, schema) is sent to server for planning.
        When server requests a client function, @use_function_invocation decorator
        intercepts and executes it locally. This matches .NET AG-UI implementation.
        """

        @ai_function
        def test_tool(param: str) -> str:
            """Test tool."""
            return "result"

        mock_events = [
            {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
            {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
        ]
        post_run_calls: list[dict[str, Any]] = []

        async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
            # Record the request rather than asserting here: assertions inside the
            # generator would silently never run if post_run were not called.
            post_run_calls.append(kwargs)
            for event in mock_events:
                yield event

        client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
        monkeypatch.setattr(client.http_service, "post_run", mock_post_run)

        messages = [ChatMessage(role="user", text="Test with tools")]
        chat_options = ChatOptions(tools=[test_tool])

        response = await client.inner_get_response(messages=messages, chat_options=chat_options)

        assert response is not None
        assert post_run_calls, "post_run was never called"
        # Client tool metadata should have been sent to the server.
        tools: list[dict[str, Any]] | None = post_run_calls[0].get("tools")
        assert tools is not None
        assert len(tools) == 1
        tool_entry = tools[0]
        assert tool_entry["name"] == "test_tool"
        assert tool_entry["description"] == "Test tool."
        assert "parameters" in tool_entry

    async def test_server_tool_calls_unwrapped_after_invocation(self, monkeypatch: MonkeyPatch) -> None:
        """Ensure server-side tool calls are exposed as FunctionCallContent after processing."""
        mock_events = [
            {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
            {"type": "TOOL_CALL_START", "toolCallId": "call_1", "toolName": "get_time_zone"},
            {"type": "TOOL_CALL_ARGS", "toolCallId": "call_1", "delta": '{"location": "Seattle"}'},
            {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
        ]

        async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
            for event in mock_events:
                yield event

        client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
        monkeypatch.setattr(client.http_service, "post_run", mock_post_run)

        messages = [ChatMessage(role="user", text="Test server tool execution")]
        chat_options = ChatOptions()

        # Use the public entry point so the full post-processing pipeline runs.
        updates: list[ChatResponseUpdate] = []
        async for update in client.get_streaming_response(messages, chat_options=chat_options):
            updates.append(update)

        function_calls = [
            content for update in updates for content in update.contents if isinstance(content, FunctionCallContent)
        ]
        assert function_calls
        assert function_calls[0].name == "get_time_zone"
        # The internal ServerFunctionCallContent wrapper must not leak to callers.
        assert not any(
            isinstance(content, ServerFunctionCallContent) for update in updates for content in update.contents
        )

    async def test_server_tool_calls_not_executed_locally(self, monkeypatch: MonkeyPatch) -> None:
        """Server tools should not trigger local function invocation even when client tools exist."""

        @ai_function
        def client_tool() -> str:
            """Client tool stub."""
            return "client"

        mock_events = [
            {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
            {"type": "TOOL_CALL_START", "toolCallId": "call_1", "toolName": "get_time_zone"},
            {"type": "TOOL_CALL_ARGS", "toolCallId": "call_1", "delta": '{"location": "Seattle"}'},
            {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
        ]

        async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
            for event in mock_events:
                yield event

        local_invocations: list[str] = []

        async def fake_auto_invoke(*args: object, **kwargs: Any) -> None:
            # Guard args[0] so a keyword-only call does not mask the real failure with IndexError.
            function_call = kwargs.get("function_call_content") or (args[0] if args else None)
            name = getattr(function_call, "name", "?")
            local_invocations.append(name)
            raise AssertionError(f"Unexpected local execution of server tool: {name}")

        monkeypatch.setattr("agent_framework._tools._auto_invoke_function", fake_auto_invoke)

        client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
        monkeypatch.setattr(client.http_service, "post_run", mock_post_run)

        messages = [ChatMessage(role="user", text="Test server tool execution")]
        chat_options = ChatOptions(tool_choice="auto", tools=[client_tool])

        async for _ in client.get_streaming_response(messages, chat_options=chat_options):
            pass

        # NOTE(review): the invocation layer may turn tool exceptions into results,
        # which would swallow the raise above; the recorded calls catch that case.
        assert not local_invocations, f"Server tools executed locally: {local_invocations}"

    async def test_state_transmission(self, monkeypatch: MonkeyPatch) -> None:
        """State extracted from the messages is forwarded to ``post_run`` as ``state``."""
        state_data = {"user_id": "123", "session": "abc"}
        messages = [
            ChatMessage(role="user", text="Hello"),
            _json_data_message(json.dumps(state_data)),
        ]

        mock_events = [
            {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
            {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
        ]
        post_run_calls: list[dict[str, Any]] = []

        async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
            # Record the request; verified after the call so a missing request fails the test.
            post_run_calls.append(kwargs)
            for event in mock_events:
                yield event

        client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
        monkeypatch.setattr(client.http_service, "post_run", mock_post_run)

        chat_options = ChatOptions()

        response = await client.inner_get_response(messages=messages, chat_options=chat_options)

        assert response is not None
        assert post_run_calls, "post_run was never called"
        assert post_run_calls[0].get("state") == state_data
